# coding: utf8
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import sys
import os
import time
import codecs

import numpy as np
import cv2

from utils.config import cfg
import data_aug as aug
from pdseg.data_utils import GeneratorEnqueuer
from models.model_builder import ModelPhase
import copy


def cv2_imread(file_path, flag=cv2.IMREAD_COLOR):
    """Read an image file with OpenCV.

    The file is first loaded as raw bytes through NumPy and then decoded with
    ``cv2.imdecode``. This sidesteps ``cv2.imread``'s trouble opening paths
    with non-ASCII (e.g. Chinese) characters on Windows.

    Args:
        file_path: path of the encoded image file.
        flag: OpenCV imread flag controlling the decoded color mode.

    Returns:
        Whatever ``cv2.imdecode`` returns for ``flag`` (None if the bytes
        cannot be decoded).
    """
    raw_bytes = np.fromfile(file_path, dtype=np.uint8)
    decoded = cv2.imdecode(raw_bytes, flag)
    return decoded


class LaneNetDataset():
    """File-list backed dataset for LaneNet.

    Each line of the file list names an image, a semantic label (binarized to
    {0, 1} on load) and an instance label, separated by
    ``cfg.DATASET.SEPARATOR``. All paths are relative to ``data_dir``. Samples
    are built by ``process_image`` according to ``mode`` (a ModelPhase value).
    """

    def __init__(self,
                 file_list,
                 data_dir,
                 shuffle=False,
                 mode=ModelPhase.TRAIN):
        """
        Args:
            file_list: path of a UTF-8 encoded file list, one sample per line.
            data_dir: directory that the paths in the file list are relative to.
            shuffle: whether to reshuffle the sample order on every pass.
            mode: ModelPhase value selecting train / eval / visual processing.
        """
        self.mode = mode
        self.shuffle = shuffle
        self.data_dir = data_dir

        self.shuffle_seed = 0
        # NOTE: Please ensure the file list was saved in UTF-8 encoding
        with codecs.open(file_list, 'r', 'utf-8') as flist:
            self.lines = [line.strip() for line in flist]
        # Strings are immutable, so a shallow copy fully decouples the full
        # list from self.lines (which may later hold only this trainer's shard).
        self.all_lines = list(self.lines)
        if shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
        elif shuffle:
            np.random.shuffle(self.lines)

    def _reshuffle_lines(self):
        """Prepare self.lines for a new pass over the data (no-op if not shuffling).

        With more than one trainer, self.all_lines is shuffled with a seeded
        RandomState, so trainers starting from the same file list presumably
        agree on the order. This trainer then keeps its contiguous slice
        number cfg.TRAINER_ID, and the seed is bumped for the next pass. The
        last len(all_lines) % NUM_TRAINERS lines are left out of every shard.
        With a single trainer, self.lines is shuffled with the global NumPy RNG.
        """
        if self.shuffle and cfg.NUM_TRAINERS > 1:
            np.random.RandomState(self.shuffle_seed).shuffle(self.all_lines)
            num_lines = len(self.all_lines) // cfg.NUM_TRAINERS
            self.lines = self.all_lines[num_lines * cfg.TRAINER_ID:num_lines *
                                        (cfg.TRAINER_ID + 1)]
            self.shuffle_seed += 1
        elif self.shuffle:
            np.random.shuffle(self.lines)

    def generator(self):
        """Yield one processed sample per line of this trainer's file list."""
        self._reshuffle_lines()

        for line in self.lines:
            yield self.process_image(line, self.data_dir, self.mode)

    def sharding_generator(self, pid=0, num_processes=1):
        """
        Use line id as shard key for multiprocess io.
        Yields processed samples for lines whose index % num_processes == pid.
        It's a normal generator if pid=0, num_processes=1.
        """
        for index, line in enumerate(self.lines):
            # Use index and pid to shard file list
            if index % num_processes == pid:
                yield self.process_image(line, self.data_dir, self.mode)

    def batch_reader(self, batch_size):
        """Yield the first three fields of every batch built by ``batch``.

        Samples come from ``self.reader`` when a caller has attached such an
        attribute, and from ``self.generator`` otherwise.
        """
        reader = getattr(self, 'reader', self.generator)
        for batch in self.batch(reader, batch_size):
            yield batch[0], batch[1], batch[2]

    def multiprocess_generator(self, max_queue_size=32, num_processes=8):
        """Yield processed samples produced concurrently by several workers.

        The file list is split round-robin into ``num_processes``
        ``sharding_generator`` instances. These are driven by a
        GeneratorEnqueuer with ``workers=num_processes``. Samples are yielded
        in the order they are taken from the enqueuer's queue. Iteration ends
        once the enqueuer is no longer running and no sample was obtained.
        """
        # Re-shuffle file list
        self._reshuffle_lines()

        # One sharding generator per worker
        generators = [
            self.sharding_generator(pid, num_processes)
            for pid in range(num_processes)
        ]

        # Bound before the try so the finally clause never hits an unbound
        # name if constructing the enqueuer itself raises.
        enqueuer = None
        try:
            enqueuer = GeneratorEnqueuer(generators)
            enqueuer.start(max_queue_size=max_queue_size, workers=num_processes)
            while True:
                generator_out = None
                while enqueuer.is_running():
                    if not enqueuer.queue.empty():
                        generator_out = enqueuer.queue.get(timeout=5)
                        break
                    else:
                        time.sleep(0.01)
                if generator_out is None:
                    break
                yield generator_out
        finally:
            if enqueuer is not None:
                enqueuer.stop()

    def batch(self, reader, batch_size, is_test=False, drop_last=False):
        """Group samples from ``reader()`` into batches of ``batch_size``.

        Args:
            reader: zero-argument callable returning an iterable of samples.
            batch_size: number of samples per batch.
            is_test: if True, samples are 6-tuples (the visual-mode layout of
                ``process_image``) and batches are (imgs, grts, grts_instance,
                img_names, valid_shapes, org_shapes). Otherwise samples are
                4-tuples and batches are (imgs, labs, labs_instance, ignore).
            drop_last: if True, drop a trailing batch smaller than batch_size.

        Returns:
            A generator of batches. Array fields are stacked with np.array;
            image names stay a plain list.
        """
        def batch_reader(is_test=False, drop_last=drop_last):
            if is_test:
                imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []
                for img, grt, grt_instance, img_name, valid_shape, org_shape in reader(
                ):
                    imgs.append(img)
                    grts.append(grt)
                    grts_instance.append(grt_instance)
                    img_names.append(img_name)
                    valid_shapes.append(valid_shape)
                    org_shapes.append(org_shape)
                    if len(imgs) == batch_size:
                        yield np.array(imgs), np.array(grts), np.array(
                            grts_instance), img_names, np.array(
                                valid_shapes), np.array(org_shapes)
                        imgs, grts, grts_instance, img_names, valid_shapes, org_shapes = [], [], [], [], [], []

                if not drop_last and len(imgs) > 0:
                    yield np.array(imgs), np.array(grts), np.array(
                        grts_instance), img_names, np.array(
                            valid_shapes), np.array(org_shapes)
            else:
                imgs, labs, labs_instance, ignore = [], [], [], []
                bs = 0
                for img, lab, lab_instance, ig in reader():
                    imgs.append(img)
                    labs.append(lab)
                    labs_instance.append(lab_instance)
                    ignore.append(ig)
                    bs += 1
                    if bs == batch_size:
                        yield np.array(imgs), np.array(labs), np.array(
                            labs_instance), np.array(ignore)
                        bs = 0
                        imgs, labs, labs_instance, ignore = [], [], [], []

                if not drop_last and bs > 0:
                    yield np.array(imgs), np.array(labs), np.array(
                        labs_instance), np.array(ignore)

        return batch_reader(is_test, drop_last)

    def load_image(self, line, src_dir, mode=ModelPhase.TRAIN):
        """Read the image and label maps named by one file-list line.

        Args:
            line: "image{sep}label{sep}instance_label" with sep =
                cfg.DATASET.SEPARATOR. Label columns may be absent outside
                TRAIN/EVAL.
            src_dir: directory the paths in ``line`` are relative to.
            mode: ModelPhase value; TRAIN and EVAL require both labels.

        Returns:
            (img, grt, grt_instance, img_name, grt_name).
            - ``img`` always has a channel axis (gray input is converted to BGR).
            - ``grt`` is read as grayscale and binarized to {0, 1}.
            - ``grt_instance`` is read as grayscale and left as is.
            - The label fields and ``grt_name`` are None when the line has no
              labels.

        Raises:
            Exception: the line is malformed in TRAIN/EVAL; the image or the
                semantic label cannot be decoded; image and label sizes
                differ; or the channel count does not match
                cfg.DATASET.DATA_DIM / cfg.MEAN / cfg.STD.
        """
        # original image cv2.imread flag setting
        cv2_imread_flag = cv2.IMREAD_COLOR
        if cfg.DATASET.IMAGE_TYPE == "rgba":
            # For 4-channel RGBA images use IMREAD_UNCHANGED to preserve
            # the alpha channel
            cv2_imread_flag = cv2.IMREAD_UNCHANGED

        parts = line.strip().split(cfg.DATASET.SEPARATOR)
        if len(parts) != 3:
            if mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
                raise Exception(
                    "File list format incorrect! It should be"
                    " image_name{0}label_name{0}instance_label_name\\n".format(
                        cfg.DATASET.SEPARATOR))
            img_name, grt_name, grt_instance_name = parts[0], None, None
        else:
            img_name, grt_name, grt_instance_name = parts

        img_path = os.path.join(src_dir, img_name)
        # Resolve label paths up front so every error message below can
        # reference them (None when the line carries no labels).
        grt_path = None
        grt_instance_path = None
        if grt_name is not None:
            grt_path = os.path.join(src_dir, grt_name)
            grt_instance_path = os.path.join(src_dir, grt_instance_name)

        img = cv2_imread(img_path, cv2_imread_flag)
        if img is None:
            raise Exception(
                "Empty image, src_dir: {}, img: {} & lab: {}".format(
                    src_dir, img_path, grt_path))

        grt = None
        grt_instance = None
        if grt_path is not None:
            grt = cv2_imread(grt_path, cv2.IMREAD_GRAYSCALE)
            if grt is None:
                raise Exception(
                    "Empty label, src_dir: {}, img: {} & lab: {}".format(
                        src_dir, img_path, grt_path))
            # Binarize: pixels equal to 255 (or already 1) become 1,
            # everything else becomes 0.
            grt[grt == 255] = 1
            grt[grt != 1] = 0
            grt_instance = cv2_imread(grt_instance_path, cv2.IMREAD_GRAYSCALE)

        img_height = img.shape[0]
        img_width = img.shape[1]

        if grt is not None:
            if img_height != grt.shape[0] or img_width != grt.shape[1]:
                raise Exception(
                    "source img and label img must has the same size")
        elif mode == ModelPhase.TRAIN or mode == ModelPhase.EVAL:
            raise Exception(
                "Empty image, src_dir: {}, img: {} & lab: {}".format(
                    src_dir, img_path, grt_path))

        if len(img.shape) < 3:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)

        img_channels = img.shape[2]
        if img_channels < 3:
            raise Exception("PaddleSeg only supports gray, rgb or rgba image")
        if img_channels != cfg.DATASET.DATA_DIM:
            raise Exception(
                "Input image channel({}) is not match cfg.DATASET.DATA_DIM({}), img_name={}"
                .format(img_channels, cfg.DATASET.DATA_DIM, img_name))
        if img_channels != len(cfg.MEAN):
            raise Exception(
                "img name {}, img chns {} mean size {}, size unequal".format(
                    img_name, img_channels, len(cfg.MEAN)))
        if img_channels != len(cfg.STD):
            raise Exception(
                "img name {}, img chns {} std size {}, size unequal".format(
                    img_name, img_channels, len(cfg.STD)))

        return img, grt, grt_instance, img_name, grt_name

    def normalize_image(self, img):
        """Divide pixels by 255, then subtract cfg.MEAN and divide by cfg.STD.

        The last (channel) axis is moved to the front, and the result is float32.
        """
        img = img.transpose((2, 0, 1)).astype('float32') / 255.0
        img_mean = np.array(cfg.MEAN).reshape((len(cfg.MEAN), 1, 1))
        img_std = np.array(cfg.STD).reshape((len(cfg.STD), 1, 1))
        img -= img_mean
        img /= img_std

        return img

    @staticmethod
    def _sample_by_ratio(ratio):
        """Draw a random decision for an augmentation gated by ``ratio``.

        Returns False without touching the RNG when ratio <= 0. Otherwise it
        returns True with probability 1 / n, where n = 1 if ratio >= 1 and
        n = int(1 / ratio) if not.
        """
        if ratio <= 0:
            return False
        n = 1 if ratio >= 1 else int(1.0 / ratio)
        return np.random.randint(0, n) == 0

    def process_image(self, line, data_dir, mode):
        """Load one file-list line and turn it into a model-ready sample.

        Returns:
            TRAIN / EVAL: (img, grt, grt_instance, ignore).
                - img is normalized by ``normalize_image``.
                - grt is int32 with a new leading axis of size 1.
                - ignore is an int32 mask of grt != cfg.DATASET.IGNORE_INDEX.
            visual: (img, grt, grt_instance, img_name, valid_shape, ori_img).
                - valid_shape is [height, width] after resizing.
                - ori_img is a copy of the image taken before resizing.

        Raises:
            ValueError: ``mode`` is not a train, eval or visual phase.
        """
        img, grt, grt_instance, img_name, grt_name = self.load_image(
            line, data_dir, mode=mode)
        if mode == ModelPhase.TRAIN:
            img, grt, grt_instance = aug.resize(img, grt, grt_instance, mode)
            if cfg.AUG.RICH_CROP.ENABLE:
                if (cfg.AUG.RICH_CROP.BLUR and
                        self._sample_by_ratio(cfg.AUG.RICH_CROP.BLUR_RATIO)):
                    # GaussianBlur needs an odd kernel size; this yields 3..9.
                    radius = np.random.randint(3, 10)
                    if radius % 2 != 1:
                        radius += 1
                    radius = min(radius, 9)
                    img = cv2.GaussianBlur(img, (radius, radius), 0, 0)

                # NOTE(review): random_rotation, rand_scale_aspect and the
                # rand_crop call below receive only img and grt, so
                # grt_instance is not transformed with them and may end up
                # misaligned with the image. Confirm against data_aug before
                # relying on these augmentations.
                img, grt = aug.random_rotation(
                    img,
                    grt,
                    rich_crop_max_rotation=cfg.AUG.RICH_CROP.MAX_ROTATION,
                    mean_value=cfg.DATASET.PADDING_VALUE)

                img, grt = aug.rand_scale_aspect(
                    img,
                    grt,
                    rich_crop_min_scale=cfg.AUG.RICH_CROP.MIN_AREA_RATIO,
                    rich_crop_aspect_ratio=cfg.AUG.RICH_CROP.ASPECT_RATIO)
                img = aug.hsv_color_jitter(
                    img,
                    brightness_jitter_ratio=cfg.AUG.RICH_CROP.
                    BRIGHTNESS_JITTER_RATIO,
                    saturation_jitter_ratio=cfg.AUG.RICH_CROP.
                    SATURATION_JITTER_RATIO,
                    contrast_jitter_ratio=cfg.AUG.RICH_CROP.
                    CONTRAST_JITTER_RATIO)

            if cfg.AUG.FLIP and self._sample_by_ratio(cfg.AUG.FLIP_RATIO):
                # Flip along the row axis; both label maps must follow the
                # image so they stay pixel-aligned.
                img = img[::-1, :, :]
                grt = grt[::-1, :]
                if grt_instance is not None:
                    grt_instance = grt_instance[::-1]

            if cfg.AUG.MIRROR and np.random.randint(0, 2) == 1:
                # Flip along the column axis, again for image and both labels.
                img = img[:, ::-1, :]
                grt = grt[:, ::-1]
                if grt_instance is not None:
                    grt_instance = grt_instance[:, ::-1]

            img, grt = aug.rand_crop(img, grt, mode=mode)
        elif ModelPhase.is_eval(mode):
            img, grt, grt_instance = aug.resize(
                img, grt, grt_instance, mode=mode)
        elif ModelPhase.is_visual(mode):
            ori_img = img.copy()
            img, grt, grt_instance = aug.resize(
                img, grt, grt_instance, mode=mode)
            valid_shape = [img.shape[0], img.shape[1]]
        else:
            raise ValueError("Dataset mode={} Error!".format(mode))

        # Normalize image
        img = self.normalize_image(img)

        if ModelPhase.is_train(mode) or ModelPhase.is_eval(mode):
            grt = np.expand_dims(np.array(grt).astype('int32'), axis=0)
            ignore = (grt != cfg.DATASET.IGNORE_INDEX).astype('int32')
            return (img, grt, grt_instance, ignore)
        if ModelPhase.is_visual(mode):
            return (img, grt, grt_instance, img_name, valid_shape, ori_img)
